import os
import json
import torch
from tqdm import tqdm
import sys
import re
import base64
import cv2
from PIL import Image
from io import BytesIO
import random
import numpy as np
from unsloth import FastVisionModel
from transformers import TextStreamer
from transformers import AutoProcessor, AutoModelForVision2Seq
import pandas as pd
import argparse
import logging

logger = logging.getLogger(__name__)

def frame_to_data_url(frame_bgr):
    """Encode an OpenCV-style BGR frame as a 256x256 JPEG ``data:`` URL.

    Args:
        frame_bgr: Image array in BGR channel order (as produced by OpenCV),
            or None.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or None when the frame is
        missing/empty or when any conversion step raises (the error is
        printed, never propagated).
    """
    try:
        # Nothing to encode for a missing or zero-sized frame.
        if frame_bgr is None:
            return None
        if frame_bgr.size == 0:
            return None

        # OpenCV stores channels as BGR, PIL expects RGB.
        rgb_pixels = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        thumbnail = Image.fromarray(rgb_pixels).resize((256, 256), Image.LANCZOS)

        # Serialize the thumbnail to JPEG entirely in memory.
        jpeg_buffer = BytesIO()
        thumbnail.save(jpeg_buffer, format="JPEG")

        payload = base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')
        return "data:image/jpeg;base64," + payload
    except Exception as e:
        print(f"Error in frame_to_data_url: {e}")
        return None

# Command-line options: which slice of the test set to score and where the
# result JSON is written. --end is treated as an inclusive index by
# prepare_and_run_evaluation; leaving it unset means "through the last row".
parser = argparse.ArgumentParser()
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="Starting sample index",
)
parser.add_argument(
    "--end",
    type=int,
    default=None,
    help="Ending sample index",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default=".",
    help="Directory for result JSONs",
)
parser.add_argument(
    "--similarity_json",
    type=str,
    default=None,
    help="Path to similarity JSON file",
)
args = parser.parse_args()

# Optional similarity data: an empty mapping unless a JSON file was supplied.
SIMILARITY_DATA = {}
if args.similarity_json:
    with open(args.similarity_json, "r") as f:
        SIMILARITY_DATA = json.load(f)

# Load the Qwen2.5-VL processor and model once, at import time. The model is
# requested in 4-bit form with automatic dtype selection and device placement.
logger.info("Loading Qwen model...")
_QWEN_CHECKPOINT = "Qwen/Qwen2.5-VL-72B-Instruct"
processor = AutoProcessor.from_pretrained(_QWEN_CHECKPOINT, trust_remote_code=True)
_model_load_options = {
    "torch_dtype": "auto",
    "device_map": "auto",
    "load_in_4bit": True,
    "trust_remote_code": True,
}
model = AutoModelForVision2Seq.from_pretrained(_QWEN_CHECKPOINT, **_model_load_options)
logger.info("Model loaded successfully!")

# Single baseline (no-persona) system prompt, passed to verbalize() for every
# sample. The "Answer: ..." line it requests is the first pattern that
# extract_score_from_response() looks for.
# NOTE(review): the prompt promises 5 example screenshots, but verbalize()
# currently passes only the target image to the model; the example scores are
# conveyed as text only. Confirm this mismatch is intended.
baseline_system_prompt = """You are an expert evaluator of website aesthetics and design. Your task is to assess how much people would like a website based on its visual design, layout, color scheme, typography, and overall aesthetic appeal.

You will be shown 5 example website screenshots with their likeability scores (on a 0-10 scale), followed by a new website screenshot that you need to evaluate.
You can provide precise scores including decimal values (e.g., 7.5, 8.2) to better reflect your nuanced judgment.
Return your response in this exact format:
Answer: [0-10] ← You must include this numerical score.
Reason: [Explain what aspects of the website design make it appealing or unappealing, considering layout, colors, typography, and overall aesthetic quality]
"""

def verbalize(prompt, sys_prompt, images):
    """Ask the Qwen vision-language model to score the target screenshot.

    Args:
        prompt: User prompt text.
        sys_prompt: System prompt text; it is prepended to ``prompt`` and sent
            as a single user turn.
        images: Sequence of ``(data_url, score)`` pairs. The LAST entry is the
            target screenshot. Only the target is passed to the model; the
            earlier (example) entries are not decoded or sent.

    Returns:
        The model's decoded, stripped response. On failure a string is
        returned instead of raising: ``"Error: No valid images"`` when the
        target image cannot be decoded, or
        ``"Error during model inference: <details>"`` for any other error.
    """
    try:
        # Decode ONLY the target (last entry). Decoding every entry and then
        # taking the last successfully decoded one would silently score an
        # example image whenever the target itself fails to decode.
        target_image = None
        if images:
            try:
                target_url = images[-1][0]
                if target_url and target_url.startswith('data:image/jpeg;base64,'):
                    encoded = target_url.split(',')[1]
                    target_image = Image.open(BytesIO(base64.b64decode(encoded)))
            except Exception as e:
                print(f"Error processing image in verbalize: {e}")

        if target_image is None:
            print("No valid images found for model inference")
            return "Error: No valid images"

        # Create the full instruction combining system prompt and user prompt
        full_instruction = f"{sys_prompt}\n\n{prompt}"

        # NOTE(review): the prompts describe example screenshots, but only the
        # target image is attached here; example scores reach the model as
        # text only.
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": full_instruction}
            ]}
        ]

        input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(
            target_image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")

        # Generate without tracking gradients; sampling settings are
        # unchanged from the original experiment configuration.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1028,  # Increased to allow full response with score
                use_cache=True,
                temperature=0.85,
                min_p=0.1,
                do_sample=True
            )

        # Drop the prompt tokens so only the newly generated text is decoded.
        response = processor.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        print(f"Error in verbalize function: {e}")
        return f"Error during model inference: {str(e)}"

import pandas as pd
import re

def safe_load_image(image_path):
    """Load an image as a BGR array, trying OpenCV first and PIL second.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image as an OpenCV-style BGR array, or None when both loaders
        fail. Loader errors are printed, never raised.
    """
    # Primary loader: OpenCV. A None/empty result falls through to PIL.
    try:
        decoded = cv2.imread(image_path)
        if not (decoded is None or decoded.size <= 0):
            return decoded
    except Exception as e:
        print(f"cv2.imread failed for {image_path}: {e}")

    # Secondary loader: PIL, converted to OpenCV's BGR channel order.
    try:
        rgb_pixels = np.array(Image.open(image_path).convert('RGB'))
        return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"PIL fallback failed for {image_path}: {e}")

    return None

# An integer or decimal token such as "7" or "7.5".
_SCORE_NUMBER = r'(\d+(?:\.\d+)?)'

# Explicitly labelled scores, tried in priority order.
_LABELLED_SCORE_PATTERNS = (
    # Exactly the requested format: "Answer: 7.5".
    re.compile(r'Answer:\s*' + _SCORE_NUMBER, re.IGNORECASE),
    # Decorated variants: "**Answer:** 7.5", "Answer: [7.5]", "Answer: (8)".
    # The lookaheads reject partial numbers and ranges such as an echoed
    # "Answer: [0-10]" template.
    re.compile(
        r'Answer\s*\**\s*:\s*[\s*\[(]*' + _SCORE_NUMBER
        + r'(?!\.?\d)(?!\s*[-\u2013]\s*\d)',
        re.IGNORECASE,
    ),
    # "Score: 7" / "Rating: 7".
    re.compile(r'(?:score|rating):\s*' + _SCORE_NUMBER, re.IGNORECASE),
)

# "7.5/10" or "7.5 out of 10": capture the numerator, not the denominator.
_OUT_OF_TEN_PATTERN = re.compile(
    r'(?<![\d.])' + _SCORE_NUMBER + r'\s*(?:/|out\s+of)\s*10(?!\.?\d)',
    re.IGNORECASE,
)


def extract_score_from_response(response):
    """Extract a numerical score in [0, 10] from the model's response.

    Strategies, in order:
      1. A labelled value ("Answer: X", decorated variants such as
         "**Answer:** X" or "Answer: [X]", then "Score:/Rating: X"); the
         value is clamped to [0, 10].
      2. An "X/10" or "X out of 10" expression; the last one whose X lies in
         [0, 10] wins.
      3. Last resort: the last standalone number in the text that lies in
         [0, 10].

    Args:
        response: The model's response text.

    Returns:
        The score as a float, or None when nothing usable is found or when
        ``response`` is not a string.
    """
    if not isinstance(response, str):
        print(f"⚠️ Could not extract score from response")
        print(f"Response is not a string: {type(response).__name__}")
        return None

    # 1. Explicitly labelled scores.
    for pattern in _LABELLED_SCORE_PATTERNS:
        labelled = pattern.search(response)
        if labelled:
            return max(0.0, min(10.0, float(labelled.group(1))))

    # 2. "X/10" style ratings. Without this step the raw-number fallback
    #    would pick up the denominator and report 10.
    numerators = [float(m.group(1)) for m in _OUT_OF_TEN_PATTERN.finditer(response)]
    numerators = [value for value in numerators if 0 <= value <= 10]
    if numerators:
        return numerators[-1]

    # 3. Fallback - any number in the 0-10 range; take the last one.
    loose_numbers = [float(tok) for tok in re.findall(r'\b(\d+(?:\.\d+)?)\b', response)]
    in_range = [value for value in loose_numbers if 0 <= value <= 10]
    if in_range:
        return in_range[-1]

    # Debug output for failed extractions
    print(f"⚠️ Could not extract score from response")
    print(f"Response length: {len(response)} characters")
    print(f"Response ends with: '...{response[-50:]}'")
    return None

# Dataset locations, relative to the working directory.
_TEST_LIST_CSV = "website-aesthetics-datasets/rating-based-dataset/preprocess/test_list.csv"
_IMAGE_DIR = "website-aesthetics-datasets/rating-based-dataset/images/"

# Few-shot sampling: draw up to 25 candidate rows to end up with at most 5
# examples whose images actually load.
_EXAMPLE_CANDIDATES = 25
_MAX_EXAMPLES = 5

# Prefixes of the strings verbalize() returns instead of a model response.
_VERBALIZE_ERROR_PREFIXES = ("Error: No valid images", "Error during model inference:")


def _image_path(file_name):
    """Map a CSV image name to its on-disk path (the '_resized' marker is dropped)."""
    return _IMAGE_DIR + file_name.replace('_resized', '')


def _load_target_url(image_path):
    """Load and encode the target screenshot; warn and return None on any failure."""
    if not os.path.exists(image_path):
        print(f"Warning: Image file does not exist {image_path}, skipping...")
        return None
    image = safe_load_image(image_path)
    if image is None:
        print(f"Warning: Could not load image {image_path}, skipping...")
        return None
    image_url = frame_to_data_url(image)
    if image_url is None:
        print(f"Warning: Could not process image {image_path}, skipping...")
    return image_url


def _collect_examples(df, target_index):
    """Randomly pick up to _MAX_EXAMPLES other rows whose images load.

    Returns (score_lines, example_images): "Score: X.X" strings and the
    matching (data_url, score) pairs, in the same order.
    """
    other_indices = [j for j in range(df.shape[0]) if j != target_index]
    sample_size = min(_EXAMPLE_CANDIDATES, len(other_indices))
    sample_indices = random.sample(other_indices, sample_size) if other_indices else []

    example_lines = []
    example_images = []
    for idx in sample_indices:
        if len(example_images) >= _MAX_EXAMPLES:
            break
        try:
            row = df.iloc[idx]
            score = row['mean_score']
            img_path = _image_path(row['image'])
            if not os.path.exists(img_path):
                continue
            img = safe_load_image(img_path)
            if img is None:
                continue
            img_url = frame_to_data_url(img)
            if img_url is None:
                continue
            example_lines.append(f"Score: {score:.1f}")
            example_images.append((img_url, score))
        except Exception as e:
            print(f"Error processing example image {idx}: {e}")
            continue
    return example_lines, example_images


def _build_prompt(example_lines):
    """Build the user prompt, with or without few-shot example scores."""
    valid_examples = len(example_lines)
    if valid_examples > 0:
        examples_text = "\n".join(example_lines)
        return (
            f"Given the images below, the first {valid_examples} are example website screenshots "
            "with their likeability scores (on a 0-10 scale, see the list below). "
            "The last image is the one you should score. \n\n"
            "Carefully analyze the last website screenshot and provide a score between 0 to 10 "
            "based on how much people would like the website's visual design, layout, colors, "
            "typography, and overall aesthetic appeal.\n\n"
            f"Here are {valid_examples} example likeability scores (in order):\n"
            f"{examples_text}\n\n"
            "Please evaluate the final website screenshot and provide your assessment."
        )
    # No usable examples: ask for a score of the target screenshot alone.
    return (
        "Carefully analyze the website screenshot and provide a score between 0 to 10 "
        "based on how much people would like the website's visual design, layout, colors, "
        "typography, and overall aesthetic appeal.\n\n"
        "Please evaluate the website screenshot and provide your assessment."
    )


def _save_results(results, output_filename):
    """Write results as JSON via a temp file so an interrupted write cannot corrupt earlier output."""
    tmp_filename = output_filename + ".tmp"
    with open(tmp_filename, 'w') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_filename, output_filename)


def prepare_and_run_evaluation():
    """Score the selected slice of the test set and save results incrementally.

    Reads the test CSV, and for each row in ``[args.start, args.end]``
    (``--end`` inclusive, clamped to the dataset size) builds a few-shot
    prompt from randomly sampled other rows, queries the model through
    ``verbalize`` and stores the extracted score with the raw response. The
    full result list is rewritten to the output JSON after every successful
    sample. Per-sample errors are printed and the sample is skipped.

    When no example image can be loaded a target-only prompt is used. When
    ``verbalize`` reports an error, the prediction is recorded as None and the
    error text is kept in the ``reason`` field.
    """
    df = pd.read_csv(_TEST_LIST_CSV)

    # Determine slice: --end is inclusive; never index past the dataset.
    first = max(args.start, 0)
    stop = len(df) if args.end is None else min(args.end + 1, len(df))
    indices = range(first, stop)

    os.makedirs(args.output_dir, exist_ok=True)
    # Defined up front so the final message works even if no sample succeeds.
    output_filename = os.path.join(args.output_dir, f'results_qwen_nopersona_slice_{args.start}_{args.end if args.end is not None else "end"}.json')

    response_dict = []
    for i in tqdm(indices, desc="Processing Samples"):
        try:
            d = df.iloc[i]
            value = d.to_dict()

            image_url = _load_target_url(_image_path(d['image']))
            if image_url is None:
                continue

            example_lines, example_images = _collect_examples(df, i)
            valid_examples = len(example_images)

            # Continue even with few or no valid example images - only log the count
            if valid_examples == 0:
                print(f"Info: No valid example images for sample {i}, proceeding with target image only...")
            elif valid_examples < 3:
                print(f"Info: Only {valid_examples} valid example images for sample {i}, proceeding anyway...")

            # The target must be the last entry; verbalize() relies on that.
            example_images.append((image_url, None))
            prompt = _build_prompt(example_lines)

            # Get prediction
            resp = verbalize(prompt, baseline_system_prompt, example_images)
            if resp.startswith(_VERBALIZE_ERROR_PREFIXES):
                # Never mine an error message for numbers: it would turn
                # e.g. a memory size into a bogus score.
                print(f"Warning: inference failed for sample {i}, recording no prediction")
                answer = None
            else:
                answer = extract_score_from_response(resp)

            # Store results
            value.update({
                "baseline_response": {"prediction": answer, "reason": resp},
                "no_persona_prediction": answer
            })
            response_dict.append(value)

            # Incremental save
            _save_results(response_dict, output_filename)
        except Exception as e:
            print(f"Error on sample {i}: {e}")
            continue

    if response_dict:
        print(f"Finished. Results in {output_filename}")
    else:
        print(f"Finished. No results were produced; nothing written to {output_filename}")
# Run the evaluation only when this file is executed as a script. Argument
# parsing and model loading happen at module level above, so they also run
# when the file is merely imported.
if __name__ == "__main__":
    prepare_and_run_evaluation()